from subprocess import PIPE, Popen


# NOTE(review): not referenced anywhere in the visible code. Presumably the
# maximum tolerated proportion of gapped positions in a pairwise alignment
# (0.05 == 5%) -- confirm against the callers that import this constant.
MSA_GAP_PROP_THRESHOLD = 0.05


def parse_aln_out(out):
    """Iterate over the records of FASTA-formatted alignment text.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped. A line starting with ``>`` opens a new record; every following
    non-header line is concatenated (without separators) into that record's
    sequence.

    Args:
        out: FASTA text, e.g. the decoded stdout of an aligner.

    Yields:
        ``(header, sequence)`` tuples in input order. ``header`` is the
        header line without its leading ``>`` marker; any further ``>``
        characters inside the header are preserved. Nothing is yielded when
        the text contains no header line, and sequence lines that appear
        before the first header are ignored because they belong to no record.
    """
    # None (rather than '') marks "no record started yet", so a record whose
    # header is empty is still treated as a real record.
    header = None
    chunks = []
    for line in out.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line[0] == '>':
            if header is not None:
                # A new header terminates the record collected so far.
                yield header, ''.join(chunks)
            # Drop only the leading marker, not every '>' in the name.
            header = line[1:]
            chunks = []
        elif header is not None:
            chunks.append(line)
    if header is not None:
        # Flush the final record, which has no following header to close it.
        yield header, ''.join(chunks)


def msa_mafft(seqs):
    """Align sequences by running the external ``mafft`` program.

    The input is serialised to FASTA text, piped to ``mafft -`` on stdin, and
    the FASTA text written to stdout is parsed with ``parse_aln_out``.

    Args:
        seqs: One of
            * ``list`` of sequence strings -- each is lower-cased and given
              its list index as FASTA header;
            * ``dict`` mapping header -> sequence string -- sequences are
              lower-cased, keys are used as headers;
            * ``str`` -- taken verbatim as ready-made FASTA text.

    Returns:
        For ``list`` input, a list of aligned sequences in the order the
        records appear in MAFFT's output. For ``dict`` or ``str`` input, a
        dict mapping each output header (a string) to its aligned sequence.

    Raises:
        Exception: if ``seqs`` is of an unsupported type, or if MAFFT's
            stdout does not start with a FASTA header (the message carries
            the captured stdout and stderr).
    """
    if isinstance(seqs, list):
        mafft_stdin = '\n'.join(
            '>{}\n{}'.format(idx, seq.lower()) for idx, seq in enumerate(seqs)
        )
    elif isinstance(seqs, dict):
        mafft_stdin = '\n'.join(
            '>{}\n{}'.format(name, seq.lower()) for name, seq in seqs.items()
        )
    elif isinstance(seqs, str):
        mafft_stdin = seqs
    else:
        raise Exception('Unexpected type for param seqs of "{}"'.format(type(seqs)))

    p = Popen(['mafft', '-'], stdin=PIPE, stdout=PIPE, stderr=PIPE)
    stdout, stderr = p.communicate(input=mafft_stdin.encode())
    if isinstance(stdout, bytes):
        stdout = stdout.decode()
    if isinstance(stderr, bytes):
        # Decode so the error message below shows MAFFT's diagnostics as
        # readable text instead of a b'...' repr; never fail on odd bytes
        # while we are only trying to report a problem.
        stderr = stderr.decode(errors='replace')

    if not stdout.startswith('>'):
        raise Exception('MSA not generated by MAFFT stdout=\n{}\n\nstderr=\n{}\n'.format(stdout, stderr))

    if isinstance(seqs, list):
        # Headers were just positional indices; callers only want sequences.
        return [aligned for _, aligned in parse_aln_out(stdout)]
    # dict and str inputs both come back keyed by header.
    return {name: aligned for name, aligned in parse_aln_out(stdout)}



def msa_ref_vs_novel(ref_seq, novel_seq):
    """Align a reference sequence against a novel sequence with MAFFT.

    Both sequences are lower-cased and written as a two-record FASTA text
    with the headers ``ref`` and ``novel``, which is handed to ``msa_mafft``.

    Args:
        ref_seq: Reference sequence string.
        novel_seq: Novel sequence string.

    Returns:
        Tuple ``(aligned_ref, aligned_novel)`` taken from the alignment
        returned by ``msa_mafft``.

    Raises:
        AssertionError: if either the ``ref`` or the ``novel`` record is
            absent from the alignment result.
    """
    ref_lower = ref_seq.lower()
    novel_lower = novel_seq.lower()
    fasta_text = '>ref\n{}\n>novel\n{}\n'.format(ref_lower, novel_lower)

    alignment = msa_mafft(fasta_text)

    # Both records must have survived the round trip through the aligner.
    assert 'ref' in alignment
    assert 'novel' in alignment
    aligned_ref = alignment['ref']
    aligned_novel = alignment['novel']
    return aligned_ref, aligned_novel


def number_gapped_ungapped(aln1, aln2):
    """Count gapped and ungapped columns of a pairwise alignment.

    The two aligned sequences are walked in lockstep (stopping at the end of
    the shorter one). Columns where both sequences have a gap (``-``) are
    ignored entirely.

    Args:
        aln1: First aligned sequence.
        aln2: Second aligned sequence.

    Returns:
        Tuple ``(total_gapped, ungapped)``: the number of columns in which
        exactly one sequence has a gap, and the number of columns in which
        neither sequence has a gap.
    """
    gap_columns = 0
    residue_columns = 0
    for left, right in zip(aln1, aln2):
        left_is_gap = left == '-'
        right_is_gap = right == '-'
        if left_is_gap != right_is_gap:
            # Exactly one side is a gap.
            gap_columns += 1
        elif not left_is_gap:
            # Neither side is a gap (the both-gap case falls through).
            residue_columns += 1
    return gap_columns, residue_columns